{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "# Simple Reinforcement Learning with Tensorflow Part 4: Deep Q-Networks and Beyond\n",
    "\n",
    "In this iPython notebook I implement a Deep Q-Network using both Double DQN and Dueling DQN. The agent learns to solve a navigation task in a basic grid world. To learn more, read here: https://medium.com/p/8438a3e2b8df\n",
    "\n",
    "For more reinforcement learning tutorials, see:\n",
    "https://github.com/awjuliani/DeepRL-Agents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "from __future__ import division\n",
    "\n",
    "import gym\n",
    "import numpy as np\n",
    "import random\n",
    "import tensorflow as tf\n",
    "import tensorflow.contrib.slim as slim\n",
    "import matplotlib.pyplot as plt\n",
    "import scipy.misc\n",
    "import os\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "### Load the game environment"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Feel free to adjust the size of the gridworld. Making it smaller provides an easier task for our DQN agent, while making the world larger increases the challenge."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true,
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAC9xJREFUeJzt3V+sZWV9xvHv0wFEoRUQSigDPXNBMMSEwU4oFNNYYAxS\ng70ikNCYhoQb2w6NiZH2gnjHRWP0ojEhoiWVYilCJcRgUTFNk2Zk+FMLDAjiIEPBGWwtlia26K8X\ne6GHCTOzzpx99jmL3/eTnJy93r0na70z8+y1zpo975OqQlI/v7LeByBpfRh+qSnDLzVl+KWmDL/U\nlOGXmjL8UlOrCn+Sy5I8leSZJJ+Y10FJWns50g/5JNkEfBfYDuwFHgSurqon5nd4ktbKUav4tecD\nz1TVswBJvgR8GDho+E8++eRaWlpaxS4lHcqePXt4+eWXM+a1qwn/6cDzy7b3Ar99qF+wtLTErl27\nVrFLSYeybdu20a9d8xt+Sa5LsivJrv3796/17iSNtJrwvwCcsWx78zD2BlV1c1Vtq6ptp5xyyip2\nJ2meVhP+B4GzkmxJcgxwFXDPfA5L0lo74p/5q+q1JH8MfA3YBHy+qh6f25FJWlOrueFHVX0V+Oqc\njkXSAvkJP6kpwy81Zfilpgy/1JThl5oy/FJThl9qyvBLTRl+qSnDLzVl+KWmDL/UlOGXmjL8UlOG\nX2rK8EtNGX6pqcOGP8nnk+xL8tiysZOS3J/k6eH7iWt7mJLmbcyZ/6+Byw4Y+wTwjao6C/jGsC1p\nQg4b/qr6J+A/Dhj+MHDr8PhW4A/mfFyS1tiR/sx/alW9ODx+CTh1TscjaUFWfcOvZk2fB237tLFH\n2piONPw/THIawPB938FeaGOPtDEdafjvAT4yPP4I8JX5HI6kRTlsaUeS24H3Aycn2QvcCNwE3JHk\nWuA54Mq1PMi5yKjW4rXZ9brteQM46A+Eb201gYkfNvxVdfVBnrpkzsciaYH8hJ/UlOGXmjL8UlOG\nX2rK8EtNGX6pKcMvNWX4paYMv9SU4ZeaMvxSU4ZfasrwS00Zfqkpwy81Zfilpgy/1NSYxp4zkjyQ\n5IkkjyfZMYzb2iNN2Jgz/2vAx6rqHOAC4KNJzsHWHmnSxjT2vFhVDw+PfwLsBk7H1h5p0lb0M3+S\nJeA8YCcjW3ss7ZA2ptHhT3I88GXg+qp6Zflzh2rtsbRD2phGhT/J0cyCf1tV3TUMj27tkbTxjLnb\nH+AWYHdVfWrZU7b2SBN22NIO4CLgD4F/S/LoMPbnTLG1R9IvjGns+WcO3jhla480UX7CT2rK8EtN\nGX6pqTE3/N4SWtdkr6euv/Ebv6HbM7/UleGXmjL8UlOGX2rK8EtNGX6pKcMvNWX4paYMv9SU4Zea\nMvxSU4ZfasrwS02NWcPv2CTfTvKvQ2PPJ4dxG3ukCRtz5v8pcHFVnQtsBS5LcgE29kiTNqaxp6rq\nv4fNo4evwsYeadLGrtu/aVi5dx9wf1XZ2CNN3KjwV9XPqmorsBk4P8l7Dnjexh5pYlZ0t7+qfgw8\nAFyGjT3SpI25239KkhOGx28HtgNPYmOPNGljFvA8Dbg1ySZmbxZ3VNW9Sf4FG3ukyRrT2PMdZrXc\nB47/CBt7pMnyE35SU4ZfasrwS00Zfqkpwy81Zfilpgy/1JThl5oy/FJThl9qyvBLTRl+qSnDLzVl\n+KWmDL/UlOGXmjL8UlOjwz8s3/1IknuHbRt7pAlbyZl/B7B72baNPdKEjS3t2Az8PvC5ZcM29kgT\nNvbM/2ng48DPl43Z2CNN2Jh1+z8E7Kuqhw72Ght7pOkZs27/RcAVSS4HjgV+LckXGRp7qupFG3uk\n6RnT0ntDVW2uqiXgKuCbVXUNNvZIk7aaf+e/Cdie5Gng0mFb0kSMuez/har6FvCt4bGNPdKE+Qk/\nqSnDLzVl+KWmDL/UlOGXmjL8UlOGX2
rK8EtNGX6pKcMvNWX4paYMv9SU4ZeaMvxSUyv6L73T9qar\njC1I1nHf62w9f9t1SJ75paZGnfmT7AF+AvwMeK2qtiU5Cfg7YAnYA1xZVf+5Nocpad5Wcub/vara\nWlXbhm1LO6QJW81lv6Ud0oSNDX8BX0/yUJLrhrFRpR2SNqaxd/vfV1UvJPl14P4kTy5/sqoqyZve\n1x3eLK4DOPPMM1d1sJLmZ9SZv6peGL7vA+4Gzmco7QA4VGmHjT3SxjSmruu4JL/6+mPgA8BjWNoh\nTdqYy/5TgbuTvP76v62q+5I8CNyR5FrgOeDKtTtMSfN22PBX1bPAuW8ybmmHNGF+wk9qyvBLTRl+\nqSnDLzVl+KWmDL/UlOGXmjL8UlOGX2rK8EtNGX6pKcMvNWX4paYMv9SU4ZeaMvxSU4ZfampU+JOc\nkOTOJE8m2Z3kwiQnJbk/ydPD9xPX+mAlzc/YM/9ngPuq6t3MlvTajY090qSNWb33ncDvArcAVNX/\nVtWPsbFHmrQxq/duAfYDX0hyLvAQsIPJNfasX01263LwdT+AdTKBavIxl/1HAe8FPltV5wGvcsAl\nflUVB5lukuuS7Eqya//+/as9XklzMib8e4G9VbVz2L6T2ZuBjT3ShB02/FX1EvB8krOHoUuAJ7Cx\nR5q0sUWdfwLcluQY4Fngj5i9cdjYI03UqPBX1aPAtjd5ysYeaaL8hJ/UlOGXmjL8UlOGX2rK8EtN\nGX6pKcMvNWX4paYMv9SU4ZeaMvxSU4ZfasrwS00Zfqkpwy81Zfilpgy/1NSYdfvPTvLosq9Xklxv\nY480bWMW8HyqqrZW1Vbgt4D/Ae7Gxh5p0lZ62X8J8L2qeg4be6RJW2n4rwJuHx5PrLFH0nKjwz8s\n230F8PcHPmdjjzQ9KznzfxB4uKp+OGzb2CNN2ErCfzW/vOQHG3ukSRsV/iTHAduBu5YN3wRsT/I0\ncOmwLWkixjb2vAq864CxHzGhxp7ZbYl+es5aY/gJP6kpwy81Zfilpgy/1JThl5oy/FJThl9qyvBL\nTRl+qSnDLzVl+KWmDL/UlOGXmjL8UlOGX2rK8EtNGX6pqbHLeP1ZkseTPJbk9iTH2tgjTduYuq7T\ngT8FtlXVe4BNzNbvt7FHmrCxl/1HAW9PchTwDuDfsbFHmrQxXX0vAH8J/AB4EfivqvpHbOyRJm3M\nZf+JzM7yW4DfAI5Lcs3y19jYI03PmMv+S4HvV9X+qvo/Zmv3/w429kiTNib8PwAuSPKOJGG2Vv9u\nbOyRJu2wpR1VtTPJncDDwGvAI8DNwPHAHUmuBZ4DrlzLA5U0X2Mbe24Ebjxg+KdMqLFH0hv5CT+p\nKcMvNWX4paYMv9RUFlldnWQ/8Crw8sJ2uvZOxvlsZG+l+YyZy29W1agP1Cw0/ABJdlXVtoXudA05\nn43trTSfec/Fy36pKcMvNbUe4b95Hfa5lpzPxvZWms9c57Lwn/klbQxe9ktNLTT8SS5L8lSSZ5JM\natmvJGckeSDJE8N6hjuG8UmvZZhkU5JHktw7bE92PklOSHJnkieT7E5y4cTns6ZrZy4s/Ek2AX8F\nfBA4B7g6yTmL2v8cvAZ8rKrOAS4APjoc/9TXMtzB7L9ov27K8/kMcF9VvRs4l9m8JjmfhaydWVUL\n+QIuBL62bPsG4IZF7X8N5vMVYDvwFHDaMHYa8NR6H9sK5rB5+At0MXDvMDbJ+QDvBL7PcB9r2fhU\n53M68DxwErP/fXsv8IF5zmeRl/2vT+Z1e4exyUmyBJwH7GTaaxl+Gvg48PNlY1OdzxZgP/CF4ceY\nzyU5jonOpxawdqY3/FYoyfHAl4Hrq+qV5c/V7O14Ev98kuRDwL6qeuhgr5nSfJidHd8LfLaqzmP2\nMfI3XBJPaT6rXTtzjEWG/wXgjGXbm4exyUhyNLPg31ZVdw3Do9Yy3IAuAq5Isgf4EnBxki8y3fns\nBf
ZW1c5h+05mbwZTnc+q1s4cY5HhfxA4K8mWJMcwu3lxzwL3vyrD+oW3ALur6lPLnprkWoZVdUNV\nba6qJWZ/Ft+sqmuY7nxeAp5PcvYwdAnwBBOdD4tYO3PBNzEuB74LfA/4i/W+qbLCY38fs0us7wCP\nDl+XA+9idtPsaeDrwEnrfaxHMLf388sbfpOdD7AV2DX8Gf0DcOLE5/NJ4EngMeBvgLfNcz5+wk9q\nyht+UlOGX2rK8EtNGX6pKcMvNWX4paYMv9SU4Zea+n9suBU3eD/YWAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x11bbb8dd0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# gameEnv comes from a `gridworld` module that is not defined in this notebook.\n",
    "# NOTE(review): presumably a local gridworld.py next to the notebook - confirm.\n",
    "from gridworld import gameEnv\n",
    "\n",
    "# Module-level environment shared by the cells below: Qnetwork reads env.actions\n",
    "# and the training loop calls env.reset() and env.step().\n",
    "env = gameEnv(partial=False,size=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Above is an example of a starting environment in our simple game. The agent controls the blue square, and can move up, down, left, or right. The goal is to move to the green square (for +1 reward) and avoid the red square (for -1 reward). The position of the three blocks is randomized every episode."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "### Implementing the network itself"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "class Qnetwork():\n",
    "    \"\"\"Q-network graph with separate advantage and value streams (dueling layout).\n",
    "\n",
    "    Builds, in the current default TensorFlow graph, the input placeholder, four\n",
    "    convolutional layers, the advantage/value heads, the combined Q-values and the\n",
    "    Adam training op. The number of actions is read from the module-level `env.actions`.\n",
    "    \"\"\"\n",
    "    def __init__(self,h_size):\n",
    "        \"\"\"Create all placeholders, variables and ops of the network.\n",
    "\n",
    "        h_size: number of output channels of the last convolutional layer. It is split\n",
    "            into two equal halves (advantage and value), so it must be even.\n",
    "        \"\"\"\n",
    "        #The network receives a frame from the game, flattened into an array of 21168 values.\n",
    "        #It then reshapes it to 84x84x3 and processes it through four convolutional layers.\n",
    "        #With VALID padding the spatial size shrinks 84 -> 20 -> 9 -> 7 -> 1, so conv4 is [batch,1,1,h_size].\n",
    "        self.scalarInput =  tf.placeholder(shape=[None,21168],dtype=tf.float32)\n",
    "        self.imageIn = tf.reshape(self.scalarInput,shape=[-1,84,84,3])\n",
    "        self.conv1 = slim.conv2d( \\\n",
    "            inputs=self.imageIn,num_outputs=32,kernel_size=[8,8],stride=[4,4],padding='VALID', biases_initializer=None)\n",
    "        self.conv2 = slim.conv2d( \\\n",
    "            inputs=self.conv1,num_outputs=64,kernel_size=[4,4],stride=[2,2],padding='VALID', biases_initializer=None)\n",
    "        self.conv3 = slim.conv2d( \\\n",
    "            inputs=self.conv2,num_outputs=64,kernel_size=[3,3],stride=[1,1],padding='VALID', biases_initializer=None)\n",
    "        self.conv4 = slim.conv2d( \\\n",
    "            inputs=self.conv3,num_outputs=h_size,kernel_size=[7,7],stride=[1,1],padding='VALID', biases_initializer=None)\n",
    "        \n",
    "        #We take the output from the final convolutional layer and split it into separate advantage and value streams.\n",
    "        #The split is along the channel axis (axis 3), giving h_size//2 channels per stream.\n",
    "        self.streamAC,self.streamVC = tf.split(self.conv4,2,3)\n",
    "        self.streamA = slim.flatten(self.streamAC)\n",
    "        self.streamV = slim.flatten(self.streamVC)\n",
    "        #Each stream is projected by a single weight matrix (plain matmul, no bias term):\n",
    "        #one advantage per action and one scalar state value.\n",
    "        xavier_init = tf.contrib.layers.xavier_initializer()\n",
    "        self.AW = tf.Variable(xavier_init([h_size//2,env.actions]))\n",
    "        self.VW = tf.Variable(xavier_init([h_size//2,1]))\n",
    "        self.Advantage = tf.matmul(self.streamA,self.AW)\n",
    "        self.Value = tf.matmul(self.streamV,self.VW)\n",
    "        \n",
    "        #Then combine them together to get our final Q-values:\n",
    "        #Q = Value + (Advantage - mean of Advantage over the actions).\n",
    "        self.Qout = self.Value + tf.subtract(self.Advantage,tf.reduce_mean(self.Advantage,axis=1,keep_dims=True))\n",
    "        #Index of the highest Q-value for each input frame.\n",
    "        self.predict = tf.argmax(self.Qout,1)\n",
    "        \n",
    "        #Below we obtain the loss by taking the mean of the squared differences between the target and prediction Q values.\n",
    "        self.targetQ = tf.placeholder(shape=[None],dtype=tf.float32)\n",
    "        self.actions = tf.placeholder(shape=[None],dtype=tf.int32)\n",
    "        self.actions_onehot = tf.one_hot(self.actions,env.actions,dtype=tf.float32)\n",
    "        \n",
    "        #The one-hot mask keeps only the Q-value of the action that was actually taken.\n",
    "        self.Q = tf.reduce_sum(tf.multiply(self.Qout, self.actions_onehot), axis=1)\n",
    "        \n",
    "        #Note: despite its name, td_error holds the squared difference per sample.\n",
    "        self.td_error = tf.square(self.targetQ - self.Q)\n",
    "        self.loss = tf.reduce_mean(self.td_error)\n",
    "        self.trainer = tf.train.AdamOptimizer(learning_rate=0.0001)\n",
    "        self.updateModel = self.trainer.minimize(self.loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "### Experience Replay"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "This class allows us to store experiences and sample them randomly to train the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "class experience_buffer():\n",
    "    \"\"\"Bounded first-in-first-out store of experiences with uniform random sampling.\"\"\"\n",
    "\n",
    "    def __init__(self, buffer_size = 50000):\n",
    "        \"\"\"Start empty; `buffer_size` is the maximum number of experiences retained.\"\"\"\n",
    "        self.buffer = []\n",
    "        self.buffer_size = buffer_size\n",
    "    \n",
    "    def add(self,experience):\n",
    "        \"\"\"Append each item of the iterable `experience`, then evict the oldest entries.\n",
    "\n",
    "        Trimming happens after the append so that the buffer never holds more than\n",
    "        `buffer_size` items, even when a single call supplies more than `buffer_size`\n",
    "        experiences (in that case only the newest `buffer_size` are kept).\n",
    "        \"\"\"\n",
    "        self.buffer.extend(experience)\n",
    "        overflow = len(self.buffer) - self.buffer_size\n",
    "        if overflow > 0:\n",
    "            #Delete in place so that existing references to self.buffer stay valid.\n",
    "            del self.buffer[:overflow]\n",
    "            \n",
    "    def sample(self,size):\n",
    "        \"\"\"Return `size` stored experiences, chosen without replacement, as an array of shape [size,5].\n",
    "\n",
    "        Each stored experience must hold exactly five fields for the reshape to succeed.\n",
    "        random.sample raises ValueError when `size` exceeds the number of stored experiences.\n",
    "        \"\"\"\n",
    "        return np.reshape(np.array(random.sample(self.buffer,size)),[size,5])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "This is a simple function to flatten our game frames into 1-D vectors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "def processState(states):\n",
    "    \"\"\"Flatten one game frame into a 1-D vector of 21168 (= 84 * 84 * 3) values.\"\"\"\n",
    "    flat_shape = [84 * 84 * 3]\n",
    "    return np.reshape(states,flat_shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "These functions allow us to update the parameters of our target network with those of the primary network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "def updateTargetGraph(tfVars,tau):\n",
    "    \"\"\"Build the assign ops that move the second half of `tfVars` toward the first half.\n",
    "\n",
    "    The variable at position k in the first half is the source for the variable at\n",
    "    position k in the second half, which is assigned tau*source + (1-tau)*itself.\n",
    "    With an odd number of variables the last one is left untouched.\n",
    "\n",
    "    Returns the list of assign ops, one per source/destination pair.\n",
    "    \"\"\"\n",
    "    half = len(tfVars)//2\n",
    "    assign_ops = []\n",
    "    #zip stops at the shorter sequence, so a trailing unpaired variable is ignored.\n",
    "    for source,destination in zip(tfVars[:half],tfVars[half:]):\n",
    "        blended = (source.value()*tau) + ((1-tau)*destination.value())\n",
    "        assign_ops.append(destination.assign(blended))\n",
    "    return assign_ops\n",
    "\n",
    "def updateTarget(op_holder,sess):\n",
    "    \"\"\"Run every op in `op_holder` on `sess`, in order, with one sess.run call per op.\"\"\"\n",
    "    for assign_op in op_holder:\n",
    "        sess.run(assign_op)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "### Training the network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Setting all the training parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "#Training hyperparameters read by the training cell below.\n",
    "batch_size = 32 #How many experiences to sample from the replay buffer for each training step.\n",
    "update_freq = 4 #Perform one training step every update_freq environment steps.\n",
    "y = .99 #Discount factor on the target Q-values\n",
    "startE = 1 #Starting chance of random action\n",
    "endE = 0.1 #Final chance of random action\n",
    "annealing_steps = 10000. #Number of environment steps (after pre-training) over which the random-action chance is lowered linearly from startE to endE.\n",
    "num_episodes = 10000 #How many episodes of game environment to train network with.\n",
    "pre_train_steps = 10000 #How many steps of purely random actions are taken before training begins.\n",
    "max_epLength = 50 #The maximum number of steps in one episode.\n",
    "load_model = False #Whether to restore a saved model from path before training.\n",
    "path = \"./dqn\" #The directory to save our model checkpoints to.\n",
    "h_size = 512 #The size of the final convolutional layer before splitting it into Advantage and Value streams.\n",
    "tau = 0.001 #Soft-update rate: after each training step, target = tau*primary + (1-tau)*target."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved Model\n",
      "(500, 2.2999999999999998, 1)\n",
      "(1000, 1.8, 1)\n",
      "(1500, 1.6000000000000001, 1)\n",
      "(2000, 2.6000000000000001, 1)\n",
      "(2500, 1.8999999999999999, 1)\n",
      "(3000, 2.0, 1)\n",
      "(3500, 1.3, 1)\n",
      "(4000, 4.0999999999999996, 1)\n",
      "(4500, 2.6000000000000001, 1)\n",
      "(5000, 2.8999999999999999, 1)\n",
      "(5500, 2.5, 1)\n",
      "(6000, 1.3, 1)\n",
      "(6500, 2.6000000000000001, 1)\n",
      "(7000, 1.8, 1)\n",
      "(7500, 1.8, 1)\n",
      "(8000, 1.8999999999999999, 1)\n",
      "(8500, 2.0, 1)\n",
      "(9000, 2.1000000000000001, 1)\n",
      "(9500, 1.8, 1)\n",
      "(10000, 1.8999999999999999, 1)\n",
      "(10500, 2.5, 0.9549999999999828)\n",
      "(11000, 2.3999999999999999, 0.9099999999999655)\n",
      "(11500, 0.69999999999999996, 0.8649999999999483)\n"
     ]
    }
   ],
   "source": [
    "#Build the primary and target networks in a fresh graph. The primary network must be\n",
    "#created first: updateTargetGraph uses the first half of the trainable variables as the\n",
    "#source and the second half as the destination of the soft update.\n",
    "tf.reset_default_graph()\n",
    "mainQN = Qnetwork(h_size)\n",
    "targetQN = Qnetwork(h_size)\n",
    "\n",
    "init = tf.global_variables_initializer()\n",
    "\n",
    "saver = tf.train.Saver()\n",
    "\n",
    "trainables = tf.trainable_variables()\n",
    "\n",
    "targetOps = updateTargetGraph(trainables,tau)\n",
    "\n",
    "myBuffer = experience_buffer()\n",
    "\n",
    "#Set the rate of random action decrease. \n",
    "e = startE\n",
    "stepDrop = (startE - endE)/annealing_steps\n",
    "\n",
    "#create lists to contain total rewards and steps per episode\n",
    "jList = []\n",
    "rList = []\n",
    "total_steps = 0\n",
    "\n",
    "#Make a path for our model to be saved in.\n",
    "if not os.path.exists(path):\n",
    "    os.makedirs(path)\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(init)\n",
    "    if load_model:\n",
    "        print('Loading Model...')\n",
    "        ckpt = tf.train.get_checkpoint_state(path)\n",
    "        #get_checkpoint_state returns None when the directory holds no checkpoint;\n",
    "        #fail with a clear message instead of an AttributeError on None.\n",
    "        if ckpt is None or not ckpt.model_checkpoint_path:\n",
    "            raise IOError('No checkpoint found in ' + path)\n",
    "        saver.restore(sess,ckpt.model_checkpoint_path)\n",
    "    for i in range(num_episodes):\n",
    "        episodeBuffer = experience_buffer()\n",
    "        #Reset environment and get first new observation\n",
    "        s = env.reset()\n",
    "        s = processState(s)\n",
    "        d = False\n",
    "        rAll = 0\n",
    "        j = 0\n",
    "        #The Q-Network\n",
    "        while j < max_epLength: #If the agent takes longer than max_epLength moves to finish the episode, end the trial.\n",
    "            j+=1\n",
    "            #Choose an action greedily from the Q-network, with e chance of a random action.\n",
    "            #During the first pre_train_steps steps every action is random.\n",
    "            if np.random.rand(1) < e or total_steps < pre_train_steps:\n",
    "                #Use env.actions, the same action count the network was built with.\n",
    "                a = np.random.randint(0,env.actions)\n",
    "            else:\n",
    "                a = sess.run(mainQN.predict,feed_dict={mainQN.scalarInput:[s]})[0]\n",
    "            s1,r,d = env.step(a)\n",
    "            s1 = processState(s1)\n",
    "            total_steps += 1\n",
    "            #dtype=object stores the mixed row (frame, action, reward, next frame, done flag) as five\n",
    "            #objects; recent NumPy releases refuse to build such ragged arrays implicitly.\n",
    "            episodeBuffer.add(np.reshape(np.array([s,a,r,s1,d],dtype=object),[1,5])) #Save the experience to our episode buffer.\n",
    "            \n",
    "            if total_steps > pre_train_steps:\n",
    "                if e > endE:\n",
    "                    e -= stepDrop\n",
    "                \n",
    "                if total_steps % (update_freq) == 0:\n",
    "                    trainBatch = myBuffer.sample(batch_size) #Get a random batch of experiences.\n",
    "                    #Below we perform the Double-DQN update to the target Q-values:\n",
    "                    #the primary network picks the next action, the target network scores it.\n",
    "                    Q1 = sess.run(mainQN.predict,feed_dict={mainQN.scalarInput:np.vstack(trainBatch[:,3])})\n",
    "                    Q2 = sess.run(targetQN.Qout,feed_dict={targetQN.scalarInput:np.vstack(trainBatch[:,3])})\n",
    "                    #1 for transitions that did not end the episode, 0 for those that did,\n",
    "                    #so the bootstrapped term is dropped on terminal transitions.\n",
    "                    end_multiplier = -(trainBatch[:,4] - 1)\n",
    "                    doubleQ = Q2[range(batch_size),Q1]\n",
    "                    targetQ = trainBatch[:,2] + (y*doubleQ * end_multiplier)\n",
    "                    #Update the network with our target values.\n",
    "                    _ = sess.run(mainQN.updateModel, \\\n",
    "                        feed_dict={mainQN.scalarInput:np.vstack(trainBatch[:,0]),mainQN.targetQ:targetQ, mainQN.actions:trainBatch[:,1]})\n",
    "                    \n",
    "                    updateTarget(targetOps,sess) #Update the target network toward the primary network.\n",
    "            rAll += r\n",
    "            s = s1\n",
    "            \n",
    "            if d:\n",
    "                break\n",
    "        \n",
    "        #Experiences reach the shared replay buffer only once the episode is over.\n",
    "        myBuffer.add(episodeBuffer.buffer)\n",
    "        jList.append(j)\n",
    "        rList.append(rAll)\n",
    "        #Periodically save the model. \n",
    "        if i % 1000 == 0:\n",
    "            saver.save(sess,path+'/model-'+str(i)+'.ckpt')\n",
    "            print(\"Saved Model\")\n",
    "        #Every 10 episodes report total steps, mean reward of the last 10 episodes and the current e.\n",
    "        if len(rList) % 10 == 0:\n",
    "            print(total_steps,np.mean(rList[-10:]), e)\n",
    "    saver.save(sess,path+'/model-'+str(i)+'.ckpt')\n",
    "#sum(rList)/num_episodes is the mean total reward per episode, not a percentage.\n",
    "print(\"Mean reward per episode: \" + str(sum(rList)/num_episodes))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "### Checking network learning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Mean reward over time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "#Average the per-episode total reward over consecutive blocks of 100 episodes.\n",
    "#np.resize keeps only the first (len(rList)//100)*100 rewards, so a trailing partial block is dropped.\n",
    "rMat = np.resize(np.array(rList),[len(rList)//100,100])\n",
    "rMean = np.average(rMat,1)\n",
    "plt.plot(rMean)\n",
    "#Label the figure so it can be read on its own.\n",
    "plt.xlabel('Block of 100 episodes')\n",
    "plt.ylabel('Mean total reward per episode')\n",
    "plt.title('Mean reward over time')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda env:py2]",
   "language": "python",
   "name": "conda-env-py2-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
